function [X, iter, min_cost, cost_vec, e] = fista_general(grad, Xinit, L, opts, cost_F)
% FISTA_GENERAL  Fast Iterative Shrinkage-Thresholding Algorithm (FISTA)
% for linear inverse problems.
%
% Solves  X = arg min_X F(X) = f(X) + lambda*g(X)  where:
%   - X    : the variable, can be a matrix.
%   - f(X) : a smooth convex function whose gradient is Lipschitz
%            continuous with constant L.
%   - g(X) : the regularizer. Its proximal step is delegated to the external
%            helper SoftThresh(Z, lambda/L) (defined outside this file;
%            presumably element-wise soft-thresholding, i.e. g is the l1
%            norm -- confirm against SoftThresh).
%
%  INPUT:
%       grad   : a function calculating the gradient of f(X) given X,
%                called via feval(grad, X).
%       Xinit  : a matrix -- initial guess.
%       L      : a scalar, the Lipschitz constant of the gradient of f(X).
%       opts   : a struct
%           opts.lambda  : regularization parameter (required), either a
%                           scalar or a weight matrix.
%           opts.max_iter: maximum number of iterations. Default 500.
%           opts.tol     : stopping tolerance. Default 1e-8.
%                           With cost_F: every 10th iteration the relative
%                           change of F between the two latest iterates is
%                           compared against tol.
%                           Without cost_F: every iteration the mean
%                           absolute change between successive X is
%                           compared against tol.
%           opts.verbose : print progress after each iteration or not.
%                           Default false.
%       cost_F : optional, a function calculating the value of F at X
%                via feval(cost_F, X).
%
%  OUTPUT:
%      X        : solution (the last iterate).
%      iter     : number of iterations that were run.
%      min_cost : F(X) at the returned X; [] if cost_F was not supplied.
%      cost_vec : 1-by-iter vector with F at every iterate. It is recorded
%                 when cost_F is supplied and either opts.verbose is true
%                 or this output is requested (this costs one extra cost_F
%                 evaluation per iteration); [] otherwise.
%      e        : last evaluated value of the stopping measure.
%                 -1  : stopped because the cost rose above the cost of the
%                       first iterate (divergence guard).
%                 NaN : the stopping measure was never evaluated.
%
% Modifications:
% 6/3/2024: exit criterion based on cost function relative change.
% -------------------------------------------------------------------
% Original author: Tiep Vu, thv102, 4/6/2016
% (http://www.personal.psu.edu/thv102/)
% Modified 26/02/2024 Yuval Zur.
% -------------------------------------

% ---- defaults -------------------------------------------------------
if ~isfield(opts, 'max_iter')
    opts.max_iter = 500;
end
if ~isfield(opts, 'tol')
    opts.tol = 1e-8;
end
if ~isfield(opts, 'verbose')
    opts.verbose = false;
end

% cost_F is optional: everything that evaluates F is guarded by has_cost.
has_cost = nargin >= 5 && ~isempty(cost_F);
% Evaluate F on every iteration only when somebody will look at it.
track_cost = has_cost && (opts.verbose || nargout >= 4);

Linv = 1/L;
lambdaLiv = opts.lambda*Linv;

% ---- state ----------------------------------------------------------
x_old = Xinit;
x_new = Xinit;          % keeps X defined even if the loop never runs
y_old = Xinit;
t_old = 1;
cost_old = 1e10;        % "previous cost" placeholder for the first iteration
cost_first = [];        % F at the first iterate, reference for divergence
e = NaN;                % keeps e defined if the criterion is never evaluated
cost_vec = zeros(1, opts.max_iter);

% ---- main loop ------------------------------------------------------
for iter = 1:opts.max_iter
    % proximal gradient step taken from the extrapolated point y_old
    x_new = SoftThresh(y_old - Linv*feval(grad, y_old), lambdaLiv);

    % momentum update
    t_new = 0.5*(1 + sqrt(1 + 4*t_old^2));
    y_new = x_new + (t_old - 1)/t_new * (x_new - x_old);

    % evaluate F(x_new) at most once per iteration and share the value
    % between bookkeeping, progress output and the stopping test
    check_now = has_cost && mod(iter, 10) == 0;
    if has_cost && (track_cost || check_now || iter == 1)
        cost_new = feval(cost_F, x_new);
        if iter == 1
            cost_first = cost_new;
        end
    end
    if track_cost
        cost_vec(iter) = cost_new;
    end

    % show progress
    if opts.verbose
        if has_cost
            if cost_new <= cost_old
                stt = 'YES.';
            else
                stt = 'NO, check your code.';
            end
            fprintf('iter = %3d, err = %f, cost decreases? %s\n', ...
                iter, abs(cost_old - cost_new)/cost_old, stt);
        else
            if mod(iter, 5) == 0
                fprintf('.');
            end
            if mod(iter, 10) == 0
                fprintf('%d', iter);
            end
        end
    end

    % check stop criteria
    if has_cost
        if check_now
            % divergence guard: worse than the very first iterate
            if cost_new > cost_first
                e = -1;
                break
            end
            if track_cost
                % cost_old still holds F(x_old) from the previous iteration
                cost_prev = cost_old;
            else
                cost_prev = feval(cost_F, x_old);
            end
            % abs() in the denominator keeps e non-negative for negative
            % costs (so it cannot be mistaken for the -1 sentinel or pass
            % the tolerance test spuriously); eps avoids division by zero.
            e = abs(cost_prev - cost_new)/max(abs(cost_prev), eps);
            if e < opts.tol
                break;
            end
        end
    else
        % no cost function: fall back to the mean absolute change of X
        e = sum(abs(x_new(:) - x_old(:)))/numel(x_new);
        if e < opts.tol
            break;
        end
    end

    % update
    if track_cost
        cost_old = cost_new;
    end
    x_old = x_new;
    t_old = t_new;
    y_old = y_new;
end

% ---- outputs --------------------------------------------------------
if isempty(iter)
    iter = 0;           % the for-loop leaves iter empty when max_iter < 1
end
if track_cost
    cost_vec = cost_vec(1:iter);
else
    cost_vec = [];
end
X = x_new;
min_cost = [];
if nargout >= 3 && has_cost
    min_cost = feval(cost_F, X);
end
end
